import torch
import numpy as np
import pickle, json, time, re, sys
import networkx as nx
from multiprocessing import Pool
import dgl
from dgl import from_networkx
import dgl
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import os
from dgl.data import DGLDataset
from DG import Graph, Node


def _collect_endpoints(design, g_nx, reg_lst, root):
    """Gather node ids, out-edges and slack labels for the usable endpoints.

    An endpoint from ``reg_lst`` is used only when both its ``rtl_func_ori``
    and ``rtl_func_pos`` text files exist for ``design``; others are skipped.

    Args:
        design: Design name used to build the per-endpoint paths.
        g_nx: Directed dependency graph whose nodes are endpoint names.
        reg_lst: Endpoint names, in the order they should be processed.
        root: Project root directory under which all inputs live.

    Returns:
        ``(node_lst, edge_lst, slack_lst)``:
        - ``node_lst``: integer ids of the used endpoints, where ids follow
          the iteration order of ``g_nx.nodes()``;
        - ``edge_lst``: ``(src_id, dst_id)`` pairs for every out-edge of a
          used endpoint;
        - ``slack_lst``: the ``'slack'`` value from each used endpoint's
          cone PPA JSON file, in ``reg_lst`` order.

    Raises:
        ValueError: if a usable endpoint is not a node of ``g_nx``.
    """
    node2int_dct = {node: idx for idx, node in enumerate(g_nx.nodes())}
    func_ori_dir = f'{root}/text_enc/llm_extra/rtl_func_ori/{design}'
    func_pos_dir = f'{root}/text_enc/llm_extra/rtl_func_pos/{design}'
    cone_ppa_dir = f'{root}/data_collect/label/ppa/cone_pwr_area/{design}'

    node_lst, edge_lst, slack_lst = [], [], []
    for ep in reg_lst:
        # Only endpoints that have both RTL function descriptions are used.
        if not (os.path.exists(f'{func_ori_dir}/{ep}.txt')
                and os.path.exists(f'{func_pos_dir}/{ep}.txt')):
            continue
        if ep not in node2int_dct:
            raise ValueError(
                f'endpoint {ep!r} of design {design!r} is not a node of '
                f'its dependency graph'
            )

        src = node2int_dct[ep]
        node_lst.append(src)
        # On a DiGraph, edges(ep) yields the out-edges (ep, successor).
        edge_lst.extend((src, node2int_dct[dst]) for _, dst in g_nx.edges(ep))

        with open(f'{cone_ppa_dir}/{ep}.json', 'r', encoding='utf-8') as f:
            slack_lst.append(json.load(f)['slack'])

    return node_lst, edge_lst, slack_lst


def run_one_design(design, root='/home/coguest5/hdl_fusion'):
    """Load one design's endpoint graph and print how many endpoints are labelled.

    Reads ``{design}_ep_graph.pkl`` and the endpoint list ``{design}.json``.
    It then collects node ids, out-edges and slack labels for each endpoint
    that has both RTL function descriptions. Finally it prints the number of
    slack labels found.

    Args:
        design: Design name used to build every input path.
        root: Project root containing ``data_collect/`` and ``text_enc/``.
            Defaults to the location the script was written against.

    Returns:
        None.

    Raises:
        ValueError: if a usable endpoint is missing from the graph.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # graph files produced by this project's own pipeline.
    with open(f'{root}/data_collect/depend_graph/data/{design}_ep_graph.pkl', 'rb') as f:
        graph = pickle.load(f)
    with open(f'{root}/data_collect/label/ep_lst/{design}.json', 'r', encoding='utf-8') as f:
        reg_lst = json.load(f)
    # Presumably `graph` is a networkx graph or an adjacency mapping that
    # nx.DiGraph accepts — TODO confirm against the graph-building step.
    g_nx = nx.DiGraph(graph)

    node_lst, edge_lst, slack_lst = _collect_endpoints(design, g_nx, reg_lst, root)
    print(len(slack_lst))
    # TODO: building a DGL graph from node_lst / edge_lst / slack_lst is not
    # implemented yet; currently only the label count is reported.


def save_dataset_one_design(design_lst):
    """Run ``run_one_design`` on each design name, announcing it first.

    Designs are handled one after another in iteration order. Nothing is
    returned.
    """
    for current_design in design_lst:
        print('Current Design: ', current_design)
        run_one_design(current_design)



if __name__ == '__main__':
    # Module-level mode flag; a `global` declaration is unnecessary at module scope.
    cmd = "ori"

    # The design list path is resolved against the current working directory.
    with open("../dataset_js/design_all.json", 'r') as design_file:
        design_lst = json.load(design_file)

    save_dataset_one_design(design_lst)
